
[Kernel] Add FP8 KV cache support to Triton MLA decode attention#34597

Merged
vllm-bot merged 1 commit into vllm-project:main from grimulkan:fp8-triton-mla
Mar 12, 2026
Conversation

@grimulkan
Contributor

@grimulkan grimulkan commented Feb 16, 2026

Enable fp8/fp8_e4m3 KV cache for the Triton MLA attention backend, which is the only MLA backend available on sm120 GPUs.

  • Add fp8 and fp8_e4m3 to TritonMLABackend.supported_kv_cache_dtypes
  • Thread k_scale/v_scale through decode attention kernel launch path
  • Add FP8 dequant-on-load in both stage1 Triton kernels (MHA and grouped/MLA)
  • Set supports_quant_query_input=False for FP8 (BF16 queries + FP8 KV)
  • Add FP8-specific parametrized test cases

Purpose

Enable FP8 KV cache for MLA models on sm120 (Blackwell consumer GPUs). The Triton MLA backend is the only MLA backend available on sm120, but it previously rejected FP8 KV cache with a NotImplementedError.

Changes

  • Add "fp8" and "fp8_e4m3" to TritonMLABackend.supported_kv_cache_dtypes
  • Thread k_scale/v_scale through the decode attention kernel launch path
  • Add FP8 dequant-on-load pattern in both Triton kernels (grouped and non-grouped paths)
  • Set supports_quant_query_input=False for (BF16 queries + FP8 KV cache)
  • Add 16 FP8-specific parametrized test cases
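The dequant-on-load pattern in the changes above can be sketched in plain Python. This is an illustrative model of the per-tensor-scale FP8 scheme, not the PR's actual Triton kernel code; `FP8_E4M3_MAX`, `quantize_kv`, and `dequant_on_load` are hypothetical names:

```python
# Illustrative model of per-tensor-scale FP8 KV caching (hypothetical names,
# not vLLM's actual kernel code). Real FP8 storage also introduces rounding
# error, which this sketch deliberately omits to show only the scale handling.

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3


def quantize_kv(values):
    """Compute a per-tensor k_scale and store values / k_scale, clamped to FP8 range."""
    amax = max(abs(v) for v in values)
    k_scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    stored = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / k_scale)) for v in values]
    return stored, k_scale


def dequant_on_load(stored, k_scale):
    """What a stage1 kernel does per loaded tile: cast to float and multiply by the scale."""
    return [v * k_scale for v in stored]
```

In the real kernels the multiply happens inside the Triton program right after the load, so BF16 math is used for the attention computation while the cache stays in FP8.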

Test Plan

  • Run pytest tests/kernels/attention/test_triton_decode_attention.py
  • Run inference with vllm serve on a compatible model using the Triton MLA backend
  • End-to-end lm_eval (GSM8K)

Test Results

  • All 112 tests pass (96 existing BF16 + 16 new FP8) in tests/kernels/attention/test_triton_decode_attention.py
  • Tested on 16 x RTX 6000 Pro GPUs with Kimi K2.5 with DCP = 1 and fp8 KV cache
local-completions ({'model': 'Kimi-K2.5', 'tokenizer_backend': '/models/moonshotai-Kimi-K2.5', 'base_url': 'http://localhost:8000/v1/completions', 'num_concurrent': 500, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: 1

--attention-backend TRITON_MLA --kv-cache-dtype auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9378|±  |0.0067|
|     |       |strict-match    |     5|exact_match|↑  |0.9378|±  |0.0067|

--attention-backend TRITON_MLA --kv-cache-dtype fp8

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9393|±  |0.0066|
|     |       |strict-match    |     5|exact_match|↑  |0.9393|±  |0.0066|

The ~0.15-pt accuracy difference is within expected tolerance. Normalized generation speed (ignoring the potentially 2x higher concurrency with fp8) is about the same as bf16, which is expected with this dequant-on-load approach.

Known limitations



Copilot AI review requested due to automatic review settings February 16, 2026 03:20

@mergify

mergify bot commented Feb 16, 2026

Documentation preview: https://vllm--34597.org.readthedocs.build/en/34597/

@mergify mergify bot added documentation Improvements or additions to documentation v1 labels Feb 16, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request successfully enables FP8 KV cache support for the Triton MLA decode attention backend. The changes are well-implemented across the documentation, backend configuration, Triton kernels, and tests.

The key changes include:

  • Updating the TritonMLABackend to advertise support for fp8 and fp8_e4m3 KV cache data types.
  • Threading k_scale and v_scale through the decode attention call stack to the Triton kernels.
  • Implementing on-the-fly dequantization for FP8 tensors within the Triton kernels, which is efficient as it leverages compile-time checks.
  • Adding a comprehensive set of parameterized tests to validate the FP8 implementation against a BF16 reference, using appropriate precision tolerances for FP8 arithmetic.

The implementation is robust, and the changes are consistent and correct. The code quality is high, and I have no major concerns. This is a solid contribution to improving performance on newer GPU architectures.

Contributor

Copilot AI left a comment


Pull request overview

This pull request enables FP8 KV cache support for the Triton MLA (Multi-head Latent Attention) decode attention backend, which is the only MLA backend available on sm120 (Blackwell consumer) GPUs. The implementation uses Mode 1 FP8 (BF16 queries + FP8 KV cache) where FP8 tensors are dequantized on load inside the Triton kernels.

Changes:

  • Added FP8 and FP8_e4m3 to the list of supported KV cache data types for TritonMLABackend
  • Threaded k_scale and v_scale parameters through all decode attention kernel launch paths
  • Implemented FP8 dequantization in both stage1 Triton kernels (standard MHA and grouped/MLA paths)
  • Added comprehensive FP8-specific parametrized test cases with proper quantization and validation

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.

File Description
vllm/v1/attention/ops/triton_decode_attention.py Added k_scale/v_scale parameters and FP8 dequantization logic in both stage1 kernels; provides dummy 1.0 scales when None
vllm/v1/attention/backends/mla/triton_mla.py Added fp8/fp8_e4m3 to supported dtypes, set supports_quant_query_input=False for Mode 1, and passed layer scales to kernel
tests/kernels/attention/test_triton_decode_attention.py Added test_decode_attention_fp8 with 16 parametrized test cases covering various configurations
docs/design/attention_backends.md Updated TRITON_MLA backend documentation to reflect new KV cache dtype support
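The file table mentions that the decode attention op substitutes dummy 1.0 scales when none are passed. That defaulting pattern can be sketched as follows; this is a hedged, simplified illustration with hypothetical names, not the actual signature in `triton_decode_attention.py`:

```python
# Simplified sketch of scale defaulting (hypothetical names): a neutral scale
# of 1.0 lets the same kernel launch path serve both BF16 caches (where the
# multiply is a no-op) and FP8 caches (where it performs the dequantization).

def launch_decode_attention(q, k_cache, v_cache, k_scale=None, v_scale=None):
    k_scale = 1.0 if k_scale is None else k_scale
    v_scale = 1.0 if v_scale is None else v_scale
    # ... the Triton kernel launch would happen here; the scales are
    # returned only so this sketch is observable.
    return k_scale, v_scale
```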


@pavanimajety
Collaborator

@grimulkan thanks for your contribution! Could you please add end to end lm_eval results for any of the standard models that can run on SM120?

@grimulkan
Contributor Author

local-completions ({'model': 'Kimi-K2.5', 'tokenizer_backend': '/models/moonshotai-Kimi-K2.5', 'base_url': 'http://localhost:8000/v1/completions', 'num_concurrent': 500, 'tokenized_requests': False}), gen_kwargs: ({}), limit: None, num_fewshot: 5, batch_size: 1

--attention-backend TRITON_MLA --kv-cache-dtype auto

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9378|±  |0.0067|
|     |       |strict-match    |     5|exact_match|↑  |0.9378|±  |0.0067|

--attention-backend TRITON_MLA --kv-cache-dtype fp8

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9393|±  |0.0066|
|     |       |strict-match    |     5|exact_match|↑  |0.9393|±  |0.0066|

I think the accuracy drop (~0.15 pts) is well within expected tolerance. Normalized generation speed (ignoring the potential 2x higher concurrency with fp8) is about the same as bf16, which is to be expected in this approach.

Let me know if you need more tests like MMLU, etc. They take longer to run.

By the way, I noticed --kv-cache-dtype bfloat16 is not actually parsed correctly by some higher-level MLA dispatch (unrelated to this PR) and crashes at inference; probably a separate bug. However, --kv-cache-dtype auto (or omitting the option) does map internally to bf16 (visible in the KV cache usage), so that is how I tested the baseline.

@voipmonitor
Contributor

I'm confirming that this is working on 8x RTX PRO GPUs on an AMD Turin host:

NCCL_P2P_LEVEL=SYS VLLM_LOG_STATS_INTERVAL=1 NCCL_GRAPH_FILE=/mnt/nccl_graph_opt.xml VLLM_TEST_FORCE_FP8_MARLIN=1 VLLM_MARLIN_USE_ATOMIC_ADD=1 VLLM_MARLIN_INPUT_DTYPE=fp8 vllm serve moonshotai/Kimi-K2.5 --served-model-name Kimi-K2.5 --trust-remote-code --host 0.0.0.0 --port 5000 --tensor-parallel-size 8 --pipeline-parallel-size 1 --enable-chunked-prefill --enable-prefix-caching --load-format fastsafetensors --tool-call-parser kimi_k2 --enable-auto-tool-choice --reasoning-parser kimi_k2 --async-scheduling --gpu-memory-utilization 0.93 --max-num-batched-tokens 4096 --mm-processor-cache-gb 0 --mm-encoder-tp-mode weights --language-model-only --attention-backend TRITON_MLA --kv-cache-dtype fp8

GPU KV cache size: 449,600 tokens
speed: 79 tok/sec

when --decode-context-parallel-size 8 is used (more KV cache):
GPU KV cache size: 3,621,504 tokens
speed: 66 tok/sec

@grimulkan
Contributor Author

Cross-posting these results here:

Some speed/VRAM benchmarks on sm120.

Kimi K2.5 on RTX 6000 Pro (native int4 experts, Marlin gemm, Triton MLA)

|Cards|TP|DCP|PP|KV Cache|Total KV Cache Space|Generation Speed (@ 0 context)|
|----:|-:|--:|-:|--------|-------------------:|-----------------------------:|
|    8| 8|  8| 1|fp8     |              3M tok|                      68 tok/s|
|    8| 8|  1| 1|fp8     |            380K tok|                      79 tok/s|
|    8| 8|  8| 1|bf16    |            1.5M tok|                      67 tok/s|
|    8| 8|  1| 1|bf16    |            190K tok|                      78 tok/s|
|   16|16| 16| 1|fp8     |             20M tok|                      43 tok/s|
|   16|16|  1| 1|fp8     |           1.25M tok|                      64 tok/s|
|   16|16| 16| 1|bf16    |             10M tok|                      42 tok/s|
|   16|16|  1| 1|bf16    |            638K tok|                      60 tok/s|

All fp8 versions use this PR, and the DCP8 versions additionally use #34795. The unlocked KV cache savings are pretty huge.

NOTE: Likely this PR needs to be rebased & features merged if #33529 is merged before this one.

@pavanimajety pavanimajety added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 23, 2026
@pavanimajety
Collaborator

@LucasWilkinson For review

ec-jt added a commit to ec-jt/vllm that referenced this pull request Mar 1, 2026
Cherry-picked from:
- PR vllm-project#34597: FP8 KV cache support for Triton MLA decode attention
- PR vllm-project#34795: Enable FP8 KV cache with Decode Context Parallel (DCP) for MLA

Changes:
- Add fp8/fp8_e4m3 to TritonMLABackend.supported_kv_cache_dtypes
- Thread k_scale/v_scale through decode attention kernel
- Add FP8 dequant-on-load in Triton kernels
- Enable DCP + FP8 KV cache combination
- Add gather_and_maybe_dequant_cache for FP8 DCP prefill path
@grimulkan
Contributor Author

Rebased, no change in performance or functionality.

I experimented with supports_quant_query_input = True, which allows a quantized Q all-gather with fp8 KV cache and DCP, but I found it was a net negative in performance on sm120: the quant/dequant compute overhead on Q did not beat the DCP communication savings at any setting from DCP = 1 to 16. Since sm120 is the main target of this kernel, I decided to leave supports_quant_query_input = False and accept BF16 Q as input, while still allowing fp8 KV and getting the benefits from that.

The non-sm120 attention backends currently all set supports_quant_query_input = True when fp8 KV cache is enabled, so this is a departure from that convention, but it is a valid one, and supports_quant_query_input = False appears to be handled just fine by other operations.

No change - just recording here for posterity.
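The capability flags discussed above can be illustrated with a simplified, hypothetical backend sketch. This is not vLLM's actual `TritonMLABackend` definition; it only shows how the Mode-1 configuration (BF16 queries + FP8 KV cache) is expressed as class attributes:

```python
# Hypothetical, simplified sketch of the backend-capability flags (not the
# real vLLM class): the backend advertises FP8 KV cache dtypes while keeping
# queries in BF16 by declining quantized query input.

class TritonMLABackendSketch:
    supported_kv_cache_dtypes = ["auto", "bfloat16", "fp8", "fp8_e4m3"]
    # Mode 1: BF16 queries + FP8 KV cache; Q is never quantized before attention.
    supports_quant_query_input = False

    @classmethod
    def accepts_kv_cache_dtype(cls, dtype: str) -> bool:
        return dtype in cls.supported_kv_cache_dtypes
```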

@mergify

mergify bot commented Mar 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @grimulkan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 9, 2026
@grimulkan
Contributor Author

Rebased (only documentation conflict)

@mergify

mergify bot commented Mar 11, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @grimulkan.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Collaborator

@MatthewBonanni MatthewBonanni left a comment


Overall looks good, thanks for doing this! Just a few small comments

Enable fp8/fp8_e4m3 KV cache for the Triton MLA attention backend,
which is the only MLA backend available on sm120 GPUs.

- Add fp8 and fp8_e4m3 to TritonMLABackend.supported_kv_cache_dtypes
- Thread k_scale/v_scale through decode attention kernel launch path
- Add FP8 dequant-on-load in both stage1 Triton kernels (MHA and grouped/MLA)
- Set supports_quant_query_input=False for FP8 (BF16 queries + FP8 KV)
- Add FP8-specific parametrized test cases

Signed-off-by: grimulkan <grimulkan@gmail.com>
Collaborator

@MatthewBonanni MatthewBonanni left a comment


LGTM

@MatthewBonanni MatthewBonanni enabled auto-merge (squash) March 12, 2026 14:39
@vllm-bot vllm-bot merged commit a1257fd into vllm-project:main Mar 12, 2026
50 of 52 checks passed
@andyluo7

✅ ROCm MI300X Verification — All Tests Pass

Tested this PR on AMD Instinct MI300X (gfx942, ROCm 7.0.2) by patching into vllm/vllm-openai-rocm:v0.17.1.

Unit Tests: 16/16 PASSED ✅

test_decode_attention_fp8[1-16384-128-128-32-32-1025-3]   PASSED
test_decode_attention_fp8[1-16384-128-128-8-32-1025-3]    PASSED
test_decode_attention_fp8[1-16384-128-576-32-32-1025-3]   PASSED
test_decode_attention_fp8[1-16384-128-576-8-32-1025-3]    PASSED
test_decode_attention_fp8[1-16384-512-128-32-32-1025-3]   PASSED
test_decode_attention_fp8[1-16384-512-128-8-32-1025-3]    PASSED
test_decode_attention_fp8[1-16384-512-576-32-32-1025-3]   PASSED
test_decode_attention_fp8[1-16384-512-576-8-32-1025-3]    PASSED
test_decode_attention_fp8[16-16384-128-128-32-32-1025-3]  PASSED
test_decode_attention_fp8[16-16384-128-128-8-32-1025-3]   PASSED
test_decode_attention_fp8[16-16384-128-576-32-32-1025-3]  PASSED
test_decode_attention_fp8[16-16384-128-576-8-32-1025-3]   PASSED
test_decode_attention_fp8[16-16384-512-128-32-32-1025-3]  PASSED
test_decode_attention_fp8[16-16384-512-128-8-32-1025-3]   PASSED
test_decode_attention_fp8[16-16384-512-576-32-32-1025-3]  PASSED
test_decode_attention_fp8[16-16384-512-576-8-32-1025-3]   PASSED

16 passed in 12.81s

E2E Serving: DeepSeek-V2-Lite (MLA) with FP8 KV cache ✅

  • Config: --attention-backend TRITON_MLA --kv-cache-dtype fp8 --tensor-parallel-size 1
  • Throughput: ~72 tok/s @ concurrency=1 (256 output tokens)
  • KV cache capacity: 9.59M tokens (vs ~4.8M BF16 — ~2x memory savings)
  • Output quality: Correct, coherent generation confirmed

ROCm Notes

  • Kernel code is fully cross-platform — no changes needed for ROCm
  • k.dtype.is_fp8() correctly handles both float8_e4m3fn (CUDA) and float8_e4m3fnuz (ROCm)
  • FP8 dequant via simple float32 multiply works identically on both platforms
  • No tl.range() or other ROCm-incompatible Triton ops
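The cross-platform dtype check noted above (Triton's `k.dtype.is_fp8()` covering both the CUDA and ROCm FP8 variants) can be mimicked over dtype names. This is a hypothetical helper for illustration, not Triton or vLLM code:

```python
# Hedged sketch of the cross-platform FP8 dtype check (hypothetical helper).
# Triton exposes this as tl.dtype.is_fp8(); here we mimic it by name so the
# point about covering both hardware variants is explicit.

FP8_DTYPE_NAMES = {
    "float8_e4m3fn",    # CUDA (sm89+/sm120) variant
    "float8_e4m3fnuz",  # ROCm (gfx942) variant
}


def is_fp8(dtype_name: str) -> bool:
    return dtype_name in FP8_DTYPE_NAMES
```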

Great work @grimulkan! 🎉

@grimulkan grimulkan deleted the fp8-triton-mla branch March 14, 2026 22:40
athrael-soju pushed a commit to athrael-soju/vllm that referenced this pull request Mar 16, 2026
…m-project#34597)

Signed-off-by: grimulkan <grimulkan@gmail.com>
Signed-off-by: Athrael Soju <athrael.soju@gmail.com>
Lucaskabela pushed a commit to Lucaskabela/vllm that referenced this pull request Mar 17, 2026
wendyliu235 pushed a commit to wendyliu235/vllm-public that referenced this pull request Mar 18, 2026
fxdawnn pushed a commit to fxdawnn/vllm that referenced this pull request Mar 19, 2026

Labels

documentation (Improvements or additions to documentation), ready (ONLY add when PR is ready to merge/full CI is needed), v1

7 participants